{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ceab8707",
   "metadata": {},
   "source": [
    "Copyright (c) MONAI Consortium  \n",
    "Licensed under the Apache License, Version 2.0 (the \"License\");  \n",
    "you may not use this file except in compliance with the License.  \n",
    "You may obtain a copy of the License at  \n",
    "&nbsp;&nbsp;&nbsp;&nbsp;http://www.apache.org/licenses/LICENSE-2.0  \n",
    "Unless required by applicable law or agreed to in writing, software  \n",
    "distributed under the License is distributed on an \"AS IS\" BASIS,  \n",
    "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  \n",
    "See the License for the specific language governing permissions and  \n",
    "limitations under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b68c35c3",
   "metadata": {},
   "source": [
    "# Self-Supervised Learning Using 3D CT Volumes"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "95fc8d67",
   "metadata": {},
   "source": [
    "##### Note:\n",
    "The [TCIA Covid 19](https://wiki.cancerimagingarchive.net/display/Public/CT+Images+in+COVID-19) dataset needs to be downloaded before this notebook can be executed. Ideally, download the data to the location given by the `data_root` variable; if the data is stored at another location, please change `data_root` accordingly.\n",
    "\n",
    "Please see the Readme for more details about the dataset. Data splits have already been provided."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "707541a2",
   "metadata": {},
   "source": [
    "## Setup environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "4dc0237b",
   "metadata": {},
   "outputs": [],
   "source": [
    "!python -c \"import monai\" || pip install -q \"monai-weekly[pillow, tqdm]\"\n",
    "!python -c \"import matplotlib\" || pip install -q matplotlib\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "49070e05",
   "metadata": {},
   "source": [
    "## Setup imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf64bf41",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "import time\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from torch.nn import L1Loss\n",
    "from monai.utils import set_determinism, first\n",
    "from monai.networks.nets import ViTAutoEnc\n",
    "from monai.losses import ContrastiveLoss\n",
    "from monai.data import DataLoader, Dataset\n",
    "from monai.config import print_config\n",
    "from monai.transforms import (\n",
    "    LoadImaged,\n",
    "    Compose,\n",
    "    CropForegroundd,\n",
    "    CopyItemsd,\n",
    "    SpatialPadd,\n",
    "    EnsureChannelFirstd,\n",
    "    Spacingd,\n",
    "    OneOf,\n",
    "    ScaleIntensityRanged,\n",
    "    RandSpatialCropSamplesd,\n",
    "    RandCoarseDropoutd,\n",
    "    RandCoarseShuffled,\n",
    ")\n",
    "\n",
    "print_config()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "72e2e12c",
   "metadata": {},
   "source": [
    "##### Define file paths & output directory path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "657217e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# JSON data list that holds the \"training\" and \"validation\" splits\n",
    "json_path = os.path.normpath(\"./datalists/tcia/dataset_split.json\")\n",
    "# Directory that is prepended to every \"image\" entry of the data list\n",
    "data_root = os.path.normpath(\"/workspace/datasets/tcia/\")\n",
    "# Output directory for the best checkpoint and the loss plots\n",
    "# NOTE(review): the value below looks like a placeholder - set it before running\n",
    "logdir_path = os.path.normpath(\"/to/be/defined\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7adf9d64",
   "metadata": {},
   "source": [
    "##### Create result logging directories, manage data paths & set determinism"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f084405c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the logging directory together with any missing parent directories;\n",
    "# `exist_ok=True` makes re-running this cell safe when the directory already exists\n",
    "os.makedirs(logdir_path, exist_ok=True)\n",
    "\n",
    "# Load the data list that contains the training / validation splits\n",
    "with open(json_path, \"r\") as json_f:\n",
    "    json_data = json.load(json_f)\n",
    "\n",
    "train_data = json_data[\"training\"]\n",
    "val_data = json_data[\"validation\"]\n",
    "\n",
    "# Prefix every image entry of both splits with the dataset root directory\n",
    "for data_list in (train_data, val_data):\n",
    "    for item in data_list:\n",
    "        item[\"image\"] = os.path.join(data_root, item[\"image\"])\n",
    "\n",
    "print(f\"Total Number of Training Data Samples: {len(train_data)}\")\n",
    "print(train_data)\n",
    "print(\"#\" * 10)\n",
    "print(f\"Total Number of Validation Data Samples: {len(val_data)}\")\n",
    "print(val_data)\n",
    "print(\"#\" * 10)\n",
    "\n",
    "# Set Determinism\n",
    "set_determinism(seed=123)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d106d4ea",
   "metadata": {},
   "source": [
    "##### Define MONAI Transforms "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8ebbdd8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define Training Transforms\n",
    "# The pipeline produces three entries per sample: \"image\" and \"image_2\" (two augmented views)\n",
    "# and \"gt_image\" (a copy that none of the augmentations below touch)\n",
    "train_transforms = Compose(\n",
    "    [\n",
    "        # Read the file referenced by \"image\" and make it channel-first\n",
    "        LoadImaged(keys=[\"image\"]),\n",
    "        EnsureChannelFirstd(keys=[\"image\"]),\n",
    "        # Resample to a pixdim of 2.0 along each of the three axes (bilinear interpolation)\n",
    "        Spacingd(keys=[\"image\"], pixdim=(2.0, 2.0, 2.0), mode=(\"bilinear\")),\n",
    "        # Map intensities from [-57, 164] to [0.0, 1.0]; clip=True bounds values outside that range\n",
    "        ScaleIntensityRanged(\n",
    "            keys=[\"image\"],\n",
    "            a_min=-57,\n",
    "            a_max=164,\n",
    "            b_min=0.0,\n",
    "            b_max=1.0,\n",
    "            clip=True,\n",
    "        ),\n",
    "        # Crop to the foreground of the image, then pad to 96 x 96 x 96\n",
    "        # (presumably so that the fixed-size crops below always fit)\n",
    "        CropForegroundd(keys=[\"image\"], source_key=\"image\", allow_smaller=True),\n",
    "        SpatialPadd(keys=[\"image\"], spatial_size=(96, 96, 96)),\n",
    "        # Draw 2 random crops of fixed size 96 x 96 x 96 from every volume\n",
    "        RandSpatialCropSamplesd(keys=[\"image\"], roi_size=(96, 96, 96), random_size=False, num_samples=2),\n",
    "        # Duplicate \"image\" twice before any augmentation is applied\n",
    "        CopyItemsd(keys=[\"image\"], times=2, names=[\"gt_image\", \"image_2\"], allow_missing_keys=False),\n",
    "        # First view: one of two coarse-dropout variants (dropout_holes=True or dropout_holes=False),\n",
    "        # followed by coarse shuffling with probability 0.8\n",
    "        OneOf(\n",
    "            transforms=[\n",
    "                RandCoarseDropoutd(\n",
    "                    keys=[\"image\"], prob=1.0, holes=6, spatial_size=5, dropout_holes=True, max_spatial_size=32\n",
    "                ),\n",
    "                RandCoarseDropoutd(\n",
    "                    keys=[\"image\"], prob=1.0, holes=6, spatial_size=20, dropout_holes=False, max_spatial_size=64\n",
    "                ),\n",
    "            ]\n",
    "        ),\n",
    "        RandCoarseShuffled(keys=[\"image\"], prob=0.8, holes=10, spatial_size=8),\n",
    "        # Please note that if image and image_2 are passed through the same transform call, then because of\n",
    "        # the determinism they get augmented in the exact same way, which is not wanted here; hence the\n",
    "        # same augmentations are repeated below as separate calls for image_2\n",
    "        OneOf(\n",
    "            transforms=[\n",
    "                RandCoarseDropoutd(\n",
    "                    keys=[\"image_2\"], prob=1.0, holes=6, spatial_size=5, dropout_holes=True, max_spatial_size=32\n",
    "                ),\n",
    "                RandCoarseDropoutd(\n",
    "                    keys=[\"image_2\"], prob=1.0, holes=6, spatial_size=20, dropout_holes=False, max_spatial_size=64\n",
    "                ),\n",
    "            ]\n",
    "        ),\n",
    "        RandCoarseShuffled(keys=[\"image_2\"], prob=0.8, holes=10, spatial_size=8),\n",
    "    ]\n",
    ")\n",
    "\n",
    "\n",
    "# Sanity check: run the pipeline on the first training sample and print the shape of one volume\n",
    "check_ds = Dataset(data=train_data, transform=train_transforms)\n",
    "check_loader = DataLoader(check_ds, batch_size=1)\n",
    "check_data = first(check_loader)\n",
    "# Index away the two leading dimensions (presumably batch and channel - TODO confirm)\n",
    "image = check_data[\"image\"][0][0]\n",
    "print(f\"image shape: {image.shape}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4e3453c4",
   "metadata": {},
   "source": [
    "##### Training Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ceb5728e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training Config\n",
    "\n",
    "# Define Network ViT backbone & Loss & Optimizer\n",
    "# Use the first GPU when CUDA is available, otherwise fall back to the CPU so the cell still runs\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "# `img_size` matches the 96 x 96 x 96 crops produced by the training transforms\n",
    "model = ViTAutoEnc(\n",
    "    in_channels=1,\n",
    "    img_size=(96, 96, 96),\n",
    "    patch_size=(16, 16, 16),\n",
    "    proj_type=\"conv\",\n",
    "    hidden_size=768,\n",
    "    mlp_dim=3072,\n",
    ")\n",
    "\n",
    "model = model.to(device)\n",
    "\n",
    "# Define Hyper-parameters for training loop\n",
    "max_epochs = 500\n",
    "val_interval = 2\n",
    "batch_size = 4\n",
    "lr = 1e-4\n",
    "\n",
    "# Histories filled by the training loop (one entry per step / epoch / validation run)\n",
    "epoch_loss_values = []\n",
    "step_loss_values = []\n",
    "epoch_cl_loss_values = []\n",
    "epoch_recon_loss_values = []\n",
    "val_loss_values = []\n",
    "# Start at infinity so that the first validation loss always counts as an improvement\n",
    "best_val_loss = float(\"inf\")\n",
    "\n",
    "recon_loss = L1Loss()\n",
    "contrastive_loss = ContrastiveLoss(temperature=0.05)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n",
    "\n",
    "\n",
    "# Define DataLoaders using MONAI (a plain, non-caching `Dataset` is used here)\n",
    "train_ds = Dataset(data=train_data, transform=train_transforms)\n",
    "train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4)\n",
    "\n",
    "# Validation reuses the training transforms because the validation loss is the reconstruction\n",
    "# error on augmented inputs; shuffling brings no benefit for validation, so it is disabled\n",
    "val_ds = Dataset(data=val_data, transform=train_transforms)\n",
    "val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=4)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "60b1912d",
   "metadata": {},
   "source": [
    "##### Training loop with validation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12d71ecd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Self-supervised pre-training loop: every epoch makes one pass over the training data and,\n",
    "# whenever `epoch % val_interval == 0`, runs validation, checkpoints the best model and saves loss plots\n",
    "for epoch in range(max_epochs):\n",
    "    print(\"-\" * 10)\n",
    "    print(f\"epoch {epoch + 1}/{max_epochs}\")\n",
    "    model.train()\n",
    "    epoch_loss = 0\n",
    "    epoch_cl_loss = 0\n",
    "    epoch_recon_loss = 0\n",
    "    step = 0\n",
    "\n",
    "    for batch_data in train_loader:\n",
    "        step += 1\n",
    "        start_time = time.time()\n",
    "\n",
    "        # Two augmented views of the same crop plus the copy taken before augmentation\n",
    "        inputs, inputs_2, gt_input = (\n",
    "            batch_data[\"image\"].to(device),\n",
    "            batch_data[\"image_2\"].to(device),\n",
    "            batch_data[\"gt_image\"].to(device),\n",
    "        )\n",
    "        optimizer.zero_grad()\n",
    "        # The second value returned by the model is not needed here\n",
    "        outputs_v1, _ = model(inputs)\n",
    "        outputs_v2, _ = model(inputs_2)\n",
    "\n",
    "        # Collapse all non-batch dimensions so that every sample is a single vector\n",
    "        flat_out_v1 = outputs_v1.flatten(start_dim=1, end_dim=4)\n",
    "        flat_out_v2 = outputs_v2.flatten(start_dim=1, end_dim=4)\n",
    "\n",
    "        r_loss = recon_loss(outputs_v1, gt_input)\n",
    "        cl_loss = contrastive_loss(flat_out_v1, flat_out_v2)\n",
    "\n",
    "        # Adjust the CL loss by Recon Loss\n",
    "        total_loss = r_loss + cl_loss * r_loss\n",
    "\n",
    "        total_loss.backward()\n",
    "        optimizer.step()\n",
    "        epoch_loss += total_loss.item()\n",
    "        step_loss_values.append(total_loss.item())\n",
    "\n",
    "        # CL & Recon Loss Storage of Value\n",
    "        epoch_cl_loss += cl_loss.item()\n",
    "        epoch_recon_loss += r_loss.item()\n",
    "\n",
    "        end_time = time.time()\n",
    "        # `len(train_loader)` is the true number of steps, including a final partial batch\n",
    "        print(\n",
    "            f\"{step}/{len(train_loader)}, \"\n",
    "            f\"train_loss: {total_loss.item():.4f}, \"\n",
    "            f\"time taken: {end_time-start_time}s\"\n",
    "        )\n",
    "\n",
    "    # Turn the accumulated sums into per-step averages for this epoch\n",
    "    epoch_loss /= step\n",
    "    epoch_cl_loss /= step\n",
    "    epoch_recon_loss /= step\n",
    "\n",
    "    epoch_loss_values.append(epoch_loss)\n",
    "    epoch_cl_loss_values.append(epoch_cl_loss)\n",
    "    epoch_recon_loss_values.append(epoch_recon_loss)\n",
    "    print(f\"epoch {epoch + 1} average loss: {epoch_loss:.4f}\")\n",
    "\n",
    "    if epoch % val_interval == 0:\n",
    "        print(\"Entering Validation for epoch: {}\".format(epoch + 1))\n",
    "        total_val_loss = 0\n",
    "        val_step = 0\n",
    "        model.eval()\n",
    "        # Time the whole validation pass, not just its last batch\n",
    "        val_start_time = time.time()\n",
    "        # Gradients are not needed for validation; disabling them saves memory and compute\n",
    "        with torch.no_grad():\n",
    "            for val_batch in val_loader:\n",
    "                val_step += 1\n",
    "                inputs, gt_input = (\n",
    "                    val_batch[\"image\"].to(device),\n",
    "                    val_batch[\"gt_image\"].to(device),\n",
    "                )\n",
    "                print(\"Input shape: {}\".format(inputs.shape))\n",
    "                outputs, _ = model(inputs)\n",
    "                val_loss = recon_loss(outputs, gt_input)\n",
    "                total_val_loss += val_loss.item()\n",
    "        val_time = time.time() - val_start_time\n",
    "\n",
    "        total_val_loss /= val_step\n",
    "        val_loss_values.append(total_val_loss)\n",
    "        print(f\"epoch {epoch + 1} Validation avg loss: {total_val_loss:.4f}, \" f\"time taken: {val_time}s\")\n",
    "\n",
    "        if total_val_loss < best_val_loss:\n",
    "            print(f\"Saving new model based on validation loss {total_val_loss:.4f}\")\n",
    "            best_val_loss = total_val_loss\n",
    "            # Store the (1-based) epoch that produced this checkpoint, not the configured maximum\n",
    "            checkpoint = {\"epoch\": epoch + 1, \"state_dict\": model.state_dict(), \"optimizer\": optimizer.state_dict()}\n",
    "            torch.save(checkpoint, os.path.join(logdir_path, \"best_model.pt\"))\n",
    "\n",
    "        # Save a 2 x 2 grid with the four loss histories; the file is overwritten on every validation run\n",
    "        fig, axes = plt.subplots(2, 2, figsize=(8, 8))\n",
    "        loss_curves = (\n",
    "            (\"Training Loss\", epoch_loss_values),\n",
    "            (\"Validation Loss\", val_loss_values),\n",
    "            (\"Training Contrastive Loss\", epoch_cl_loss_values),\n",
    "            (\"Training Recon Loss\", epoch_recon_loss_values),\n",
    "        )\n",
    "        for ax, (title, values) in zip(axes.flat, loss_curves):\n",
    "            ax.plot(values)\n",
    "            ax.grid()\n",
    "            ax.set_title(title)\n",
    "\n",
    "        fig.savefig(os.path.join(logdir_path, \"loss_plots.png\"))\n",
    "        # Close the figure so that figures do not pile up in memory over the epochs\n",
    "        plt.close(fig)\n",
    "\n",
    "print(\"Done\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  },
  "vscode": {
   "interpreter": {
    "hash": "da3e08083059755bb70e9f8b58ba677201225f59652efa5b6b39528ae9381865"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
